import os
import torch
import argparse
from copy import deepcopy


# Command-line interface: a single optional ``--path`` option naming the
# directory that holds the checkpoint files to fix.
parser = argparse.ArgumentParser(description='Fix checkpoint')
parser.add_argument(
    '--path',
    help='checkpoint directory path',
    default=None,
)


def fix_state_dict(state_dict):
    """Strip the leading ``'module.'`` prefix from state-dict keys, in place.

    Every key of ``state_dict`` that starts with ``'module.'`` is re-inserted
    without that prefix and the prefixed entry is removed. If ``state_dict``
    carries a ``_metadata`` mapping attribute, its keys are renamed the same
    way. A mapping without ``_metadata`` (for example a plain ``dict``) is
    handled too; only its own keys are renamed.

    NOTE(review): the prefix is presumably the one left behind by a wrapper
    such as ``torch.nn.DataParallel`` -- confirm against the training code.

    Args:
        state_dict: Mutable mapping with string keys. It is modified in place.

    Returns:
        The same ``state_dict`` object, after renaming.
    """
    prefix = 'module.'

    def _strip_prefix(mapping):
        # Snapshot the matching keys first: the mapping is mutated in the
        # loop, and resizing a dict while iterating over it is an error.
        prefixed_keys = [key for key in mapping if key.startswith(prefix)]
        for key in prefixed_keys:
            mapping[key[len(prefix):]] = mapping.pop(key)

    _strip_prefix(state_dict)

    # ``_metadata`` is an optional attribute; not every state dict has one.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        _strip_prefix(metadata)

    return state_dict


def fix_checkpoint(path, filename):
    """Write a copy of a checkpoint whose state-dict keys have been fixed.

    Loads ``<path>/<filename>`` onto the CPU, passes its ``'state_dict'``
    entry through ``fix_state_dict`` and saves the result as
    ``<path>_fixed/<filename>``, creating that directory if necessary. The
    original checkpoint file is not modified.

    Nothing happens (and nothing is printed) when ``path`` is not an existing
    directory or when it does not contain ``filename``.

    Args:
        path: Directory that contains the checkpoint file.
        filename: Name of the checkpoint file inside ``path``.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    if not os.path.isdir(path):
        return

    checkpoint_file = os.path.join(path, filename)
    if not os.path.isfile(checkpoint_file):
        return

    # Normalise first so that a trailing separator ('run/') yields the sibling
    # directory 'run_fixed' rather than 'run/_fixed' inside the source one.
    fixed_path = os.path.normpath(path) + '_fixed'
    if not os.path.exists(fixed_path):
        print("creating path '%s'" % fixed_path)
        os.makedirs(fixed_path)

    print("loading checkpoint '%s'" % checkpoint_file)
    checkpoint = torch.load(checkpoint_file, map_location="cpu")

    # The loaded object is local to this call, so it can be fixed in place;
    # a deep copy would only double the peak memory use.
    checkpoint['state_dict'] = fix_state_dict(checkpoint['state_dict'])

    fixed_filename = os.path.join(fixed_path, filename)
    print("saving fixed checkpoint '%s'" % fixed_filename)
    torch.save(checkpoint, fixed_filename)


# Script driver: parse the command line, then try both checkpoint file names
# in the directory given by ``--path``.
args = parser.parse_args()
fix_checkpoint(path=args.path, filename='checkpoint.pth.tar')
fix_checkpoint(path=args.path, filename='model_best.pth.tar')

